{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "075668f2-50b7-45d2-8946-ae4a4c022891",
   "metadata": {},
   "source": [
    "# Adaptive Loss Re-Weighting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a34d6eb-5bc5-49cf-8944-0833c1978684",
   "metadata": {},
   "source": [
    "We will be using the Allen-Cahn Equation to show this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a6d5583-5583-4aed-960b-c2a1bdddcb42",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import optax\n",
    "from flax import linen as nn\n",
    "\n",
    "import sys\n",
    "import os\n",
    "\n",
    "import time\n",
    "import scipy\n",
    "\n",
    "# Add /src to path\n",
    "path_to_src = os.path.abspath(os.path.join(os.getcwd(), '../../../src'))\n",
    "if path_to_src not in sys.path:\n",
    "    sys.path.append(path_to_src)\n",
    "\n",
    "from KAN import KAN\n",
    "from PIKAN import *\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2726d02c-0c6c-4e2b-b812-5774531ae006",
   "metadata": {},
   "source": [
    "## Collocation Points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09464618-a576-4323-a1f2-8300af3c930d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate Collocation points for PDE\n",
    "N = 2**12\n",
    "collocs = jnp.array(sobol_sample(np.array([0,-1]), np.array([1,1]), N)) # (4096, 2)\n",
    "\n",
    "# Generate Collocation points for BCs\n",
    "N = 2**6\n",
    "\n",
    "BC1_colloc = jnp.array(sobol_sample(np.array([0,-1]), np.array([0,1]), N)) # (64, 2)\n",
    "BC1_data = ((BC1_colloc[:,1]**2)*jnp.cos(jnp.pi*BC1_colloc[:,1])).reshape(-1,1)\n",
    "\n",
    "BC2_colloc = jnp.array(sobol_sample(np.array([0,-1]), np.array([1,-1]), N)) # (64, 2)\n",
    "BC2_data = -jnp.ones(BC2_colloc.shape[0]).reshape(-1,1) # (64, 1)\n",
    "\n",
    "BC3_colloc = jnp.array(sobol_sample(np.array([0,1]), np.array([1,1]), N)) # (64, 2)\n",
    "BC3_data = -jnp.ones(BC3_colloc.shape[0]).reshape(-1,1) # (64, 1)\n",
    "\n",
    "# Create lists for BCs\n",
    "bc_collocs = [BC1_colloc, BC2_colloc, BC3_colloc]\n",
    "bc_data = [BC1_data, BC2_data, BC3_data]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa6d6e6f-a362-4d5e-8d1c-4fb9d0445181",
   "metadata": {},
   "source": [
    "### Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01a502e3-1c4d-4086-8d16-745db2cfbb30",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pde_loss(params, collocs, state):\n",
    "    # Eq. parameter\n",
    "    D = jnp.array(0.001, dtype=float)\n",
    "    c = jnp.array(5.0, dtype=float)\n",
    "    \n",
    "    # Define the model function\n",
    "    variables = {'params' : params, 'state' : state}\n",
    "    \n",
    "    def u(vec_x):\n",
    "        y, spl = model.apply(variables, vec_x)\n",
    "        return y\n",
    "        \n",
    "    # Physics Loss Terms\n",
    "    u_t = gradf(u, 0, 1)  # 1st order derivative of t\n",
    "    u_xx = gradf(u, 1, 2) # 2nd order derivative of x\n",
    "    \n",
    "    # Residual\n",
    "    pde_res = u_t(collocs) - D*u_xx(collocs) - c*(u(collocs)-(u(collocs)**3))\n",
    "    \n",
    "    return pde_res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3da3da60-3011-4385-9fba-d90efb379aed",
   "metadata": {},
   "source": [
    "## Training with RBA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2823958-13e4-4d65-91af-8d548a582c9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "layer_dims = [2, 8, 8, 1]\n",
    "model = KAN(layer_dims=layer_dims, k=3, const_spl=False, const_res=False, add_bias=True, grid_e=0.05)\n",
    "variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))\n",
    "\n",
    "# Define learning rates for scheduler\n",
    "lr_vals = dict()\n",
    "lr_vals['init_lr'] = 0.001\n",
    "lr_vals['scales'] = {0 : 1.0, 15_000 : 0.6}\n",
    "\n",
    "# Define epochs for grid adaptation\n",
    "adapt_every = 275\n",
    "adapt_stop = 20000\n",
    "grid_adapt = [i * adapt_every for i in range(1, (adapt_stop // adapt_every) + 1)]\n",
    "\n",
    "# Define epochs for grid extension, along with grid sizes\n",
    "grid_extend = {0 : 3, 8000 : 8}\n",
    "\n",
    "# Define global loss weights\n",
    "glob_w = [jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float)]\n",
    "\n",
    "# Initialize RBA weights\n",
    "loc_w = [jnp.ones((collocs.shape[0],1)), jnp.ones((BC1_colloc.shape[0],1)),\n",
    "         jnp.ones((BC2_colloc.shape[0],1)), jnp.ones((BC3_colloc.shape[0],1))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5873060-ccbc-41b9-8fa6-38709794f903",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 20000\n",
    "\n",
    "model, variables, train_losses = train_PIKAN(model, variables, pde_loss, collocs, bc_collocs, bc_data, glob_w=glob_w, \n",
    "                                             lr_vals=lr_vals, adapt_state=True, loc_w=loc_w, nesterov=True, \n",
    "                                             num_epochs=num_epochs, grid_extend=grid_extend, grid_adapt=grid_adapt, \n",
    "                                             colloc_adapt={'epochs' : []})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d0a5f1-32d4-4769-90a0-f678550b618a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3db416d7-9e1b-4d2b-acb7-4108655bc82b",
   "metadata": {},
   "source": [
    "## Training without RBA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4d3b0b5-62a8-4e96-b537-e450a32f56a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "layer_dims = [2, 8, 8, 1]\n",
    "model = KAN(layer_dims=layer_dims, k=3, const_spl=False, const_res=False, add_bias=True, grid_e=0.05)\n",
    "variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))\n",
    "\n",
    "# Define learning rates for scheduler\n",
    "lr_vals = dict()\n",
    "lr_vals['init_lr'] = 0.001\n",
    "lr_vals['scales'] = {0 : 1.0, 15_000 : 0.6}\n",
    "\n",
    "# Define epochs for grid adaptation\n",
    "adapt_every = 275\n",
    "adapt_stop = 20000\n",
    "grid_adapt = [i * adapt_every for i in range(1, (adapt_stop // adapt_every) + 1)]\n",
    "\n",
    "# Define epochs for grid extension, along with grid sizes\n",
    "grid_extend = {0 : 3, 8000 : 8}\n",
    "\n",
    "# Define global loss weights\n",
    "glob_w = [jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float)]\n",
    "\n",
    "# Initialize RBA weights\n",
    "loc_w = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e95bca8-2ce9-4880-8219-f0802dedb28a",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 20000\n",
    "\n",
    "model, variables, train_losses2 = train_PIKAN(model, variables, pde_loss, collocs, bc_collocs, bc_data, glob_w=glob_w, \n",
    "                                             lr_vals=lr_vals, adapt_state=True, loc_w=loc_w, nesterov=True, \n",
    "                                             num_epochs=num_epochs, grid_extend=grid_extend, grid_adapt=grid_adapt, \n",
    "                                             colloc_adapt={'epochs' : []})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0671672c-6959-4e82-b7e5-f28780321ddc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e5676400-a420-4d94-beef-b4a0b35d1303",
   "metadata": {},
   "source": [
    "## Save Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "904d131e-a0fb-4da8-8fb2-014fed282e25",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = np.arange(num_epochs)\n",
    "np.savez('../Plots/data/rba.npz', epochs=epochs, loss1=train_losses, loss2=train_losses2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c767a689-9345-4897-94d2-62d6d5b99a6c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cae5553b-3754-4809-9fd4-3b869a7fe77c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample M points from Sobol\n",
    "M = 2**16\n",
    "sample = jnp.array(sobol_sample(np.array([0,-1]), np.array([1,1]), M))\n",
    "# Draw k, c hyperparameters\n",
    "k, c = jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b33eddb-6e8a-4727-984e-94396f86ab78",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
